{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Chapter 2: Our First Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.utils.data\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "from torchvision import transforms\n",
    "from PIL import Image\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up DataLoaders\n",
    "\n",
    "We'll use torchvision's built-in `torchvision.datasets.ImageFolder` dataset to quickly set up dataloaders for the downloaded cat and fish images.\n",
    "\n",
    "`check_image` is a small function that we pass to `ImageFolder`'s `is_valid_file` parameter; it performs a sanity check that PIL can actually open each file. We use it in lieu of cleaning up the downloaded dataset by hand.\n"
   ]
  },
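  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_image(path):\n",
    "    # Return True only if PIL can open the file, so broken downloads are skipped\n",
    "    try:\n",
    "        # The context manager closes the file handle once the check is done\n",
    "        with Image.open(path):\n",
    "            return True\n",
    "    except Exception:\n",
    "        return False"
   ]
  },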
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the transforms for every image:\n",
    "\n",
    "* Resize to 64x64\n",
    "* Convert to tensor\n",
    "* Normalize using ImageNet mean & std\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "img_transforms = transforms.Compose([\n",
    "    transforms.Resize((64,64)),    \n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                    std=[0.229, 0.224, 0.225] )\n",
    "    ])\n",
    "\n"
   ]
  },
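  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`Normalize` works channel by channel: it subtracts that channel's mean and divides by its standard deviation. As a minimal sketch (the constant-valued tensor `x` is a made-up stand-in for an image, not part of the dataset), the two printed lines should agree:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchvision import transforms\n",
    "\n",
    "normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],\n",
    "                                 std=[0.229, 0.224, 0.225])\n",
    "\n",
    "# A made-up 3-channel 2x2 stand-in image in which every value is 0.5\n",
    "x = torch.full((3, 2, 2), 0.5)\n",
    "out = normalize(x)\n",
    "\n",
    "# Each channel becomes (0.5 - mean) / std for that channel\n",
    "print(out[:, 0, 0])\n",
    "print((0.5 - 0.485) / 0.229, (0.5 - 0.456) / 0.224, (0.5 - 0.406) / 0.225)"
   ]
  },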
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data_path = \"./train/\"\n",
    "train_data = torchvision.datasets.ImageFolder(root=train_data_path,transform=img_transforms, is_valid_file=check_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "val_data_path = \"./val/\"\n",
    "val_data = torchvision.datasets.ImageFolder(root=val_data_path,transform=img_transforms, is_valid_file=check_image)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_data_path = \"./test/\"\n",
    "test_data = torchvision.datasets.ImageFolder(root=test_data_path,transform=img_transforms, is_valid_file=check_image) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "batch_size=64"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ImageFolder lists samples class by class, so shuffle the training data;\n",
    "# otherwise each batch would contain images of a single class\n",
    "train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)\n",
    "val_data_loader  = torch.utils.data.DataLoader(val_data, batch_size=batch_size) \n",
    "test_data_loader  = torch.utils.data.DataLoader(test_data, batch_size=batch_size) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Our First Model, SimpleNet\n",
    "\n",
    "SimpleNet is a very simple combination of three Linear layers with ReLU activations between them. Note that we don't apply `softmax()` in `forward()`: `CrossEntropyLoss` works on the raw outputs during training, so we apply `softmax()` ourselves in the validation phase of the training function."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleNet(nn.Module):\n",
    "\n",
    "    def __init__(self):\n",
    "        super(SimpleNet, self).__init__()\n",
    "        self.fc1 = nn.Linear(12288, 84)\n",
    "        self.fc2 = nn.Linear(84, 50)\n",
    "        self.fc3 = nn.Linear(50,2)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        x = x.view(-1, 12288)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x"
   ]
  },
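  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `12288` in `fc1` and in `view()` is the number of values in one transformed image: 3 color channels of 64 by 64 pixels. A quick sketch with a random stand-in batch (`dummy_batch` is a hypothetical name, not real data) shows what the flattening does:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Random stand-in for a batch of 4 transformed images\n",
    "dummy_batch = torch.randn(4, 3, 64, 64)\n",
    "\n",
    "# Same reshape as SimpleNet.forward(): one row of 3 * 64 * 64 values per image\n",
    "flat = dummy_batch.view(-1, 3 * 64 * 64)\n",
    "print(flat.shape)  # torch.Size([4, 12288])"
   ]
  },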
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "simplenet = SimpleNet()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create an optimizer\n",
    "\n",
    "Here, we're just using Adam as our optimizer with a learning rate of 0.001."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = optim.Adam(simplenet.parameters(), lr=0.001)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Copy the model to GPU\n",
    "\n",
    "Copy the model to the GPU if available."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda\") \n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "\n",
    "simplenet.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training \n",
    "\n",
    "The `train` function trains the model, copying batches to the GPU if required, calculating losses, optimizing the network, and performing validation at the end of each epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=20, device=\"cpu\"):\n",
    "    for epoch in range(epochs):\n",
    "        training_loss = 0.0\n",
    "        valid_loss = 0.0\n",
    "        model.train()\n",
    "        for batch in train_loader:\n",
    "            optimizer.zero_grad()\n",
    "            inputs, targets = batch\n",
    "            inputs = inputs.to(device)\n",
    "            targets = targets.to(device)\n",
    "            output = model(inputs)\n",
    "            loss = loss_fn(output, targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            # Weight each batch's mean loss by its size so the epoch average is exact\n",
    "            training_loss += loss.item() * inputs.size(0)\n",
    "        training_loss /= len(train_loader.dataset)\n",
    "        \n",
    "        model.eval()\n",
    "        num_correct = 0 \n",
    "        num_examples = 0\n",
    "        # Gradients are not needed for validation\n",
    "        with torch.no_grad():\n",
    "            for batch in val_loader:\n",
    "                inputs, targets = batch\n",
    "                inputs = inputs.to(device)\n",
    "                targets = targets.to(device)\n",
    "                output = model(inputs)\n",
    "                loss = loss_fn(output, targets)\n",
    "                valid_loss += loss.item() * inputs.size(0)\n",
    "                # Predicted class = index of the highest softmax probability in each row\n",
    "                correct = torch.eq(torch.max(F.softmax(output, dim=1), dim=1)[1], targets).view(-1)\n",
    "                num_correct += torch.sum(correct).item()\n",
    "                num_examples += correct.shape[0]\n",
    "        valid_loss /= len(val_loader.dataset)\n",
    "\n",
    "        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy = {:.2f}'.format(epoch, training_loss,\n",
    "        valid_loss, num_correct / num_examples))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train(simplenet, optimizer,torch.nn.CrossEntropyLoss(), train_data_loader,val_data_loader, epochs=5, device=device)"
   ]
  },
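  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Softmax preserves the ordering of the scores, so it doesn't change which class comes out on top; it only turns the raw outputs into probabilities that sum to 1. A small sketch with made-up logits (arbitrary values, not output from our model):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "# Two made-up rows of raw scores for the classes [cat, fish]\n",
    "logits = torch.tensor([[2.0, -1.0], [0.5, 3.0]])\n",
    "probs = F.softmax(logits, dim=1)\n",
    "\n",
    "print(probs.sum(dim=1))      # each row of probabilities sums to 1\n",
    "print(logits.argmax(dim=1))  # tensor([0, 1])\n",
    "print(probs.argmax(dim=1))   # tensor([0, 1])"
   ]
  },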
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Making predictions\n",
    "\n",
    "`ImageFolder` assigns labels in sorted order of the class folder names, so `cat` will be 0 and `fish` will be 1. We'll need to apply the same transforms to the image and make sure that the resulting tensor is copied to the appropriate device before applying our model to it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "labels = ['cat','fish']\n",
    "\n",
    "img = Image.open(\"./val/fish/100_1422.JPG\") \n",
    "img = img_transforms(img).to(device)\n",
    "img = torch.unsqueeze(img, 0)  # add a batch dimension: [1, 3, 64, 64]\n",
    "\n",
    "simplenet.eval()\n",
    "# dim=1 applies softmax across the two class scores\n",
    "prediction = F.softmax(simplenet(img), dim=1)\n",
    "prediction = prediction.argmax()\n",
    "print(labels[prediction]) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving Models\n",
    "\n",
    "We can either save the entire model using `save` or just the parameters using `state_dict`. Using the latter is normally preferable, as it allows you to reuse parameters even if the model's structure changes (or apply parameters from one model to another)."
   ]
  },
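  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A `state_dict` is simply a dictionary mapping each parameter's name to its tensor. As a small illustration, here is the one for a standalone `nn.Linear` layer with the same shape as `fc3` (`layer` is a hypothetical name used only here):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "# Stand-in layer with the same shape as SimpleNet's fc3\n",
    "layer = nn.Linear(50, 2)\n",
    "\n",
    "# Prints each parameter name with its shape: weight (2, 50) and bias (2,)\n",
    "for name, tensor in layer.state_dict().items():\n",
    "    print(name, tuple(tensor.shape))"
   ]
  },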
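  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(simplenet, \"/tmp/simplenet\") \n",
    "# This reloads a whole pickled model. PyTorch 2.6 and later default to weights_only=True,\n",
    "# so there you need torch.load(..., weights_only=False), and only for files you trust\n",
    "simplenet = torch.load(\"/tmp/simplenet\")    \n"
   ]
  },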
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "torch.save(simplenet.state_dict(), \"/tmp/simplenet\")    \n",
    "simplenet = SimpleNet()\n",
    "simplenet_state_dict = torch.load(\"/tmp/simplenet\")\n",
    "simplenet.load_state_dict(simplenet_state_dict)   "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
